import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger
import torchvision
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import math

from model import zoo as models
from model.raymarcher_lebesgue import ImplicitRendererDict
from model import raymarcher_simple as raymarchers
from data.loader import load_dataset
from data import utils as collate_fns

from pytorch3d.renderer import (
    NDCGridRaysampler,
    MonteCarloRaysampler,
    EmissionAbsorptionRaymarcher,
    ImplicitRenderer
)
from trainers.utils import sample_images_at_mc_locs, huber


class LitModel(pl.LightningModule):
    def __init__(self, hparams):
        """Build the model, the ray samplers, the raymarcher and both renderers.

        Args:
            hparams: Nested configuration supporting attribute access and
                ``.get``. This constructor reads the ``model``, ``data_conf``,
                ``data`` and ``train`` sections plus the optional top-level
                keys ``debug`` and ``model_name``. Note that
                ``hparams.model.density.max_depth`` is overwritten in place.

        Raises:
            ValueError: If ``model_name`` is not an attribute of the model
                zoo module.
        """
        super().__init__()
        # Copy the scene's far bound into the density config before the
        # hyperparameters are saved, so the saved copy already contains it.
        hparams.model.density.max_depth = hparams.data_conf.max_depth
        self.save_hyperparameters(hparams)
        self.debug = hparams.get('debug', False)

        # Resolve the model class by name. A missing name must fail loudly:
        # falling back to a *string* default would only blow up later with
        # an opaque "'str' object is not callable".
        model_name = self.hparams.get('model_name', 'NeuralRadianceField')
        Model = getattr(models, model_name, None)
        if Model is None:
            raise ValueError(
                f"Unknown model_name {model_name!r}: not found in the model zoo"
            )
        self.nerf = Model(self.hparams.model)

        min_depth = self.hparams.data_conf.get('min_depth', 0)
        max_depth = self.hparams.data_conf.max_depth

        # 1) Instantiate the raysamplers.
        # Here, NDCGridRaysampler generates a rectangular image
        # grid of rays whose coordinates follow the PyTorch3D
        # coordinate conventions.
        # NOTE(review): n_pts_per_ray is read from `hparams.data` while every
        # other data setting comes from `hparams.data_conf` -- confirm that a
        # separate `data` section is intended.
        self.raysampler_grid = NDCGridRaysampler(
            image_height=self.hparams.data_conf.render_height,
            image_width=self.hparams.data_conf.render_width,
            n_pts_per_ray=self.hparams.data.n_pts_per_ray,
            min_depth=min_depth,
            max_depth=max_depth,
        )

        # MonteCarloRaysampler generates a random subset
        # of `n_rays_per_image` rays emitted from the image plane.
        self.raysampler_mc = MonteCarloRaysampler(
            min_x=-1.0,
            max_x=1.0,
            min_y=-1.0,
            max_y=1.0,
            n_rays_per_image=self.hparams.train.n_rays_per_image,
            n_pts_per_ray=self.hparams.train.n_pts_per_ray,
            min_depth=min_depth,
            max_depth=max_depth,
        )

        # 2) Instantiate the raymarcher.
        # The class is looked up by name in the raymarcher module
        # (`model.raymarcher` in the config); the default is
        # 'EmissionAbsorptionDictRaymarcher'. It is built without arguments.
        raymarcher_name = self.hparams.model.get('raymarcher', 'EmissionAbsorptionDictRaymarcher')
        self.raymarcher = getattr(raymarchers, raymarcher_name)()

        # Finally, instantiate the implicit renderers
        # for both raysamplers. They share one raymarcher instance.
        # The keyword spelling `stratified_resamling` is kept as-is because it
        # is the name this file has always passed to ImplicitRendererDict.
        stratified = self.hparams.train.get('stratified_sampling', False)
        self.renderer_grid = ImplicitRendererDict(
            raysampler=self.raysampler_grid,
            raymarcher=self.raymarcher,
            stratified_resamling=stratified,
        )

        self.renderer_mc = ImplicitRendererDict(
            raysampler=self.raysampler_mc,
            raymarcher=self.raymarcher,
            stratified_resamling=stratified,
        )

    def train_dataloader(self):
        """Load the training split and wrap it in a shuffling ``DataLoader``.

        The dataset is stored on ``self.train_dataset``. In debug mode the
        configured epoch length is replaced by 1.
        """
        train_conf = self.hparams.train
        epoch_len = train_conf.epoch_len
        if self.debug:
            # Debug runs use an epoch length of 1.
            epoch_len = 1

        self.train_dataset = load_dataset(
            self.hparams.data_conf, split='train', epoch_len=epoch_len
        )

        # The collate function is selected by name from the data utils module.
        loader_kwargs = dict(
            shuffle=True,
            num_workers=train_conf.num_workers,
            collate_fn=getattr(collate_fns, self.hparams.data_conf.collate_fn),
            batch_size=train_conf.batch_size,
            pin_memory=True,
        )
        return torch.utils.data.DataLoader(self.train_dataset, **loader_kwargs)

    def val_dataloader(self):
        """Load the validation data and wrap it in a ``DataLoader``.

        In debug mode the 'train' split with ``epoch_len=1`` is used instead
        of the 'val' split. The dataset is stored on ``self.val_dataset``.
        Batches always have size 1 and are not shuffled.
        """
        if self.debug:
            dataset_kwargs = {'split': 'train', 'epoch_len': 1}
        else:
            dataset_kwargs = {'split': 'val'}
        self.val_dataset = load_dataset(self.hparams.data_conf, **dataset_kwargs)

        # The collate function is selected by name from the data utils module.
        loader_kwargs = dict(
            shuffle=False,
            num_workers=self.hparams.train.num_workers,
            collate_fn=getattr(collate_fns, self.hparams.data_conf.collate_fn),
            batch_size=1,
            pin_memory=True,
        )
        return torch.utils.data.DataLoader(self.val_dataset, **loader_kwargs)

    def forward(self, x):
        return self.nerf(x)

    def configure_optimizers(self):
        # optim = torch.optim.Adam(self.parameters(), lr=self.hparams.train.learning_rate)
        color_parameters = [p for name, p in self.named_parameters() if 'density' not in name]
        density_parameters = [p for name, p in self.named_parameters() if 'density' in name]
        density_lr = self.hparams.train.get('init_density_lr', self.hparams.train.learning_rate)
        Optimizer = getattr(torch.optim, self.hparams.train.get('optim', 'Adam'))
        optim = Optimizer([
            {'params': color_parameters}, 
            {'params': density_parameters,'lr': density_lr}
            ], lr=self.hparams.train.learning_rate, weight_decay=self.hparams.train.get('weight_decay', 0.0)
        )
        return optim


    def on_train_epoch_start(self):
        if (self.hparams.train.get('scheduler', None) is None) and (
                self.current_epoch == round(self.hparams.train.num_epochs * 0.75)):
            print('Decreasing LR 10-fold ...')
            for g in self.trainer.optimizers[0].param_groups:
                g['lr'] = round(g['lr'] * 0.1)
        self.on_train_step_start()
    
    def on_train_step_start(self):
        def schedule_lr(step, base_lr, final_lr):
            total_steps = self.hparams.train.num_epochs * len(self.train_dataset)
            lambda_w = self.hparams.train.scheduler.lambda_w
            warmup_steps = self.hparams.train.scheduler.warmup_epochs * len(self.train_dataset)
            warmup_mult = (lambda_w + (1 - lambda_w) * math.sin((math.pi / 2) * max(min((step / warmup_steps), 1), 0)))
            eta_i = warmup_mult * (math.exp((1 - step / total_steps) * math.log(base_lr) + (step / total_steps) * math.log(final_lr)))
            return eta_i
    
        if self.hparams.train.get('scheduler', None) is not None:
            # adjust learning rate there
            # for instance we can do it only for param group 0 - color mlp
            # self.trainer.optimizers[0].param_groups[0]['lr'] = self.hparams.train.learning_rate
            self.trainer.optimizers[0].param_groups[0]['lr'] = schedule_lr(
                    self.trainer.global_step, self.hparams.train.learning_rate, self.hparams.train.scheduler.final_lr)
            if self.hparams.train.scheduler.get('update_density_lr', True):
                self.trainer.optimizers[0].param_groups[1]['lr'] = schedule_lr(
                    self.trainer.global_step, self.hparams.train.get('init_density_lr', self.hparams.train.learning_rate), 
                    self.hparams.train.scheduler.get('final_density_lr', self.hparams.train.scheduler.final_lr)
                )

    def training_step(self, batch, batch_idx):
        """Render a random subset of rays and compute the training loss.

        Args:
            batch: Mapping that provides 'cameras' and 'target_images';
                'target_silhouettes' is optional.
            batch_idx: Index of the batch (unused).

        Returns:
            The total loss: the color error, plus the weighted silhouette
            error when silhouettes are supervised, plus the L1/L2 parameter
            regularizers. Each term is also logged under 'train/...'.
        """
        silhouette_weight = self.hparams.train.get('silhouette_weight', 1)

        # Evaluate the nerf model.
        rendered_images_silhouettes, sampled_rays, _ = self.renderer_mc(
            cameras=batch['cameras'],
            volumetric_function=self.nerf
        )

        if 'target_silhouettes' in batch and silhouette_weight > 0:
            # The last rendered channel is treated as the silhouette.
            rendered_images, rendered_silhouettes = (
                rendered_images_silhouettes.split([3, 1], dim=-1)
            )

            # Compute the silhouette error as the mean huber
            # loss between the predicted masks and the
            # sampled target silhouettes.
            silhouettes_at_rays = sample_images_at_mc_locs(
                batch['target_silhouettes'][..., None],
                sampled_rays.xys
            )
            sil_err = huber(
                rendered_silhouettes,
                silhouettes_at_rays,
            ).abs().mean()
        else:
            sil_err = None
            if rendered_images_silhouettes.shape[-1] == 4:
                # Drop the unused fourth channel and keep the first three.
                rendered_images, _ = (
                    rendered_images_silhouettes.split([3, 1], dim=-1)
                )
            else:
                rendered_images = rendered_images_silhouettes

        # Compute the color error as the mean huber
        # loss between the rendered colors and the
        # sampled target images.
        colors_at_rays = sample_images_at_mc_locs(
            batch['target_images'],
            sampled_rays.xys
        )
        if self.hparams.train.get('add_noise', False):
            # Add uniform noise in [0, 1/255) to the target colors.
            colors_at_rays = colors_at_rays + torch.rand_like(colors_at_rays) / 255.0
        color_err = huber(
            rendered_images,
            colors_at_rays,
        ).abs().mean()

        # The optimization loss is the color error plus the weighted
        # silhouette error plus the parameter regularizers.
        loss = color_err
        self.log('train/color_err', color_err.item())
        if sil_err is not None:
            # Out-of-place addition: `loss` aliases `color_err` here, and an
            # in-place `+=` would overwrite that tensor as well.
            loss = loss + silhouette_weight * sil_err
            self.log('train/sil_err', sil_err.item())
        loss = loss + self.l1_l2_reg()
        self.log('train/loss_total', loss.item())
        return loss
    

    def l1_l2_reg(self):
        loss = 0.
        # L1 regularizer
        if self.hparams.train.get('l1_weight', 0) > 0:
            parameters = []
            if self.hparams.model.get('color', {}).get('linear', True):
                parameters += [p.view(-1, 1) for name, p in self.nerf.named_parameters() if p.requires_grad and 'color' in name]
            if self.hparams.model.get('density', {}).get('linear', True):
                parameters += [p.view(-1, 1) for name, p in self.nerf.named_parameters() if p.requires_grad and 'density' in name]
            parameters = torch.cat(parameters)
            l1_reg = parameters.abs().sum()
            loss += self.hparams.train.get('l1_weight', 0) * l1_reg
            self.log('train/l1_reg', l1_reg)

        # L2 regularizer
        if self.hparams.train.get('l2_weight', 0) > 0:
            parameters = []
            if self.hparams.model.get('color', {}).get('linear', True):
                parameters += [p.view(-1, 1) for name, p in self.nerf.named_parameters() if p.requires_grad and 'color' in name]
            if self.hparams.model.get('density', {}).get('linear', True):
                parameters += [p.view(-1, 1) for name, p in self.nerf.named_parameters() if p.requires_grad and 'density' in name]
            parameters = torch.cat(parameters)
            l2_reg = parameters.pow(2).sum()
            loss += self.hparams.train.get('l2_weight', 0) * l2_reg
            self.log('train/l2_reg', l2_reg)
        return loss


    def validation_step(self, batch, batch_idx):
        frame = self.renderer_grid(
                cameras=batch['cameras'], 
                volumetric_function=self.nerf.batched_forward,
            )[0][..., :3] # batch is supposed to be 1 for validation
        if 'target_images' in batch and batch['target_images'] is not None: # if we have target for this view angle
            pred_interpolated = F.interpolate(frame.permute(0, 3, 1, 2).clamp(0., 1.), size=batch['target_images'].shape[1:3], mode='bilinear', align_corners=True).permute(0, 2, 3, 1)
            mse = ((pred_interpolated - batch['target_images']) ** 2).mean()
            psnr = 10 * torch.log10(1.0/mse).item()
            if self.debug:
                print('psnr:', psnr)
        else:
            psnr = None
        return {
            'frame': frame,
            'psnr': psnr
        }
    

    def validation_epoch_end(self, outputs):
        """Log the mean PSNR and a grid of all rendered validation frames.

        Args:
            outputs: The dicts returned by ``validation_step``, each with the
                keys 'frame' and 'psnr'.
        """
        # Stack all frames, clamp them to [0, 1] and move channels first
        # before tiling them three per row.
        frames = torch.cat([out['frame'] for out in outputs])
        image_grid = torchvision.utils.make_grid(
            frames.clamp(0., 1.).permute(0, 3, 1, 2), nrow=3
        )

        # Views without a target image carry psnr=None and are skipped.
        # NOTE(review): if no view has a target, the mean of the empty list
        # is NaN and that NaN is what gets logged -- confirm this is acceptable.
        psnr_values = [out['psnr'] for out in outputs if out['psnr'] is not None]
        self.log('val/psnr', np.mean(psnr_values), on_epoch=True)
        self.logger.log_image('val/image_grid', [image_grid])